if __name__ == "__main__":
    from multiagentenv import MultiAgentEnv
else:
    from .multiagentenv import MultiAgentEnv
# from utils.dict2namedtuple import convert # This was breaking locally, so define manually
from collections import namedtuple
import numpy as np
import random
import itertools

# Discrete action ids shared by both agents (ACTIONS is the full list, in id order).
LEFT, RIGHT, UP, DOWN, NOOP = 0, 1, 2, 3, 4
ACTIONS = [LEFT, RIGHT, UP, DOWN, NOOP]

def convert(dictionary):
    """Return *dictionary* as an immutable ``GenericDict`` namedtuple.

    Field names are the dictionary keys in insertion order, so values become
    attributes (``convert({"a": 1}).a == 1``). Keys must be valid field names
    for ``collections.namedtuple``.
    """
    field_names = tuple(dictionary)
    record_type = namedtuple('GenericDict', field_names)
    return record_type(**dictionary)

class CMaze(MultiAgentEnv):
    """Two-agent corridor maze with a button, a hidden hint and optional decoy buttons.

    Agent 0 walks a corridor of length ``lengths[0]`` and presses the button by
    starting a step at ``button_loc``. Agent 1 walks a corridor of length
    ``lengths[1]``. At its right end (x == lengths[1] - 1) agent 1 may step UP
    or DOWN. That ends the episode with ``right_r`` if the direction matches
    the hint and ``wrong_r`` otherwise. Agent 1 only observes the hint after
    the button has been pressed, and only at ``hint_locs[1]`` (anywhere if None).

    With ``num_decoy_buttons > 0`` agent 0 may also step off its corridor onto
    decoy buttons placed one row above (y=1) or below (y=-1) it.

    Coordinates: agents use 0-based x along their corridor and y in {-1, 0, 1}.
    Decoy-button and decoy-hint locations are stored as ``(x + 1, y)``, i.e.
    shifted by one to match the column index used by :meth:`render`.
    """

    def __init__(self, **kwargs):
        """Build the environment from ``kwargs["env_args"]``.

        ``env_args`` is a dict (converted with ``convert``) or an object exposing
        ``lengths``, ``starts``, ``hint_locs``, ``button_loc``, ``episode_limit``,
        ``right_r``, ``wrong_r``, ``num_decoy_buttons``, ``decoy_buttons_fixed``
        and ``decoy_buttons_same_hint`` as attributes.

        Raises:
            ValueError: if more decoy buttons are requested than there are
                candidate columns (corridor length - 1).
        """
        args = kwargs["env_args"]
        if isinstance(args, dict):
            args = convert(args)

        self.n_agents = 2

        self.lengths = tuple(int(l) for l in args.lengths)
        self.starts = tuple(int(l) for l in args.starts)
        self.hint_locs = tuple(None if l is None else int(l) for l in args.hint_locs)
        self.button_loc = int(args.button_loc)
        self.episode_limit = int(args.episode_limit)
        self.right_r = float(args.right_r)
        self.wrong_r = float(args.wrong_r)
        self.num_decoy_buttons = int(args.num_decoy_buttons)
        # If fixed, decoy buttons' locations are drawn once and kept across resets.
        self.decoy_buttons_fixed = int(args.decoy_buttons_fixed)
        # If set, agent 0 may enter a decoy's hint cell (after pressing that decoy)
        # and observe the real hint there.
        self.decoy_buttons_same_hint = int(args.decoy_buttons_same_hint)

        # Decoy bookkeeping. It stays empty without decoy buttons, so membership
        # tests elsewhere are simply False instead of raising AttributeError.
        self.possible_decoy_pos = []
        self.decoy_button_locs = []
        self.decay_hint_locs = []  # (sic) "decay" means decoy; name kept for compatibility
        self.decoy_buttons_pressed = []
        if self.num_decoy_buttons > 0:
            # Decoy x's are stored shifted by +1, so the candidates are 1..lengths[0].
            # Exclude the start column (starts[0] + 1), where render() draws the
            # hint when hint_locs[0] is None.
            start_col = self.starts[0] + 1
            self.possible_decoy_pos = [
                i for i in range(1, self.lengths[0] + 1) if i != start_col
            ]
            if self.num_decoy_buttons > len(self.possible_decoy_pos):
                raise ValueError(
                    "The number of decoy buttons cannot be greater than the length of the corridor - 1"
                )
            self._gen_decoy_buttons_locs()

        # set in reset:
        self.timestep = None
        self.agent0_loc = [None, None]
        self.agent1_loc = [None, None]
        self.button_pressed = None
        self.hint = None
        self.reset()

    def reset(self):
        """Start a new episode and return ``(obs, state)``.

        Both agents return to their starts on y=0, a new hint (UP or DOWN) is
        drawn, and the button and every decoy button are released. Decoy-button
        locations are re-drawn only when ``decoy_buttons_fixed`` is not set.
        """
        self.timestep = 0
        self.agent0_loc = [self.starts[0], 0]
        self.agent1_loc = [self.starts[1], 0]
        self.hint = np.random.choice([UP, DOWN])
        self.button_pressed = False
        if self.num_decoy_buttons > 0:
            if self.decoy_buttons_fixed:
                # Keep the locations, but clear presses left over from the last episode.
                self.decoy_buttons_pressed = [False] * len(self.decoy_button_locs)
            else:
                self._gen_decoy_buttons_locs()
        return self.get_obs(), self.get_state()

    def state_transition(self, index, loc, action):
        """Return agent ``index``'s location after taking ``action`` from ``loc``.

        Side effects (agent 0 only): ``button_pressed`` is set when agent 0
        *starts* the step at x == ``button_loc``. A decoy button is marked
        pressed when agent 0 *starts* the step on it. Presses are therefore
        registered from the pre-move location.
        """
        x, y = loc
        # Without decoy buttons, an agent that has left y=0 cannot move again.
        if self.num_decoy_buttons == 0 and y != 0:
            return loc  # Stuck!

        length = self.lengths[index]

        # Decoy locations are stored with x shifted by +1, hence the offsets below:
        # the cell at agent-x X has decoy key X + 1.
        new_x, new_y = x, y
        if action == LEFT:
            if y == 0:
                new_x, new_y = (x - 1, y)
            elif index == 0 and (x, y) in self.decoy_button_locs:
                # Off-corridor, agent 0 may only slide onto an adjacent decoy button...
                new_x, new_y = (x - 1, y)
            elif index == 0 and self.decoy_buttons_same_hint and (x, y) in self.decay_hint_locs:
                # ...or onto a decoy-hint cell whose decoy has been pressed.
                if self.decoy_buttons_pressed[self.decay_hint_locs.index((x, y))]:
                    new_x, new_y = (x - 1, y)
        elif action == RIGHT:
            if y == 0:
                new_x, new_y = (x + 1, y)
            elif index == 0 and (x + 2, y) in self.decoy_button_locs:
                new_x, new_y = (x + 1, y)
            elif index == 0 and self.decoy_buttons_same_hint and (x + 2, y) in self.decay_hint_locs:
                if self.decoy_buttons_pressed[self.decay_hint_locs.index((x + 2, y))]:
                    new_x, new_y = (x + 1, y)
        elif action == UP:
            if index == 1 and x == length - 1:  # agent 1 can only turn at the end of its corridor
                new_x, new_y = (x, y + 1)
            if index == 0 and self.num_decoy_buttons > 0:
                if (x + 1, y + 1) in self.decoy_button_locs:
                    # Moving onto a decoy button
                    new_x, new_y = (x, y + 1)
                elif self.decoy_buttons_same_hint and (x + 1, y + 1) in self.decay_hint_locs:
                    if self.decoy_buttons_pressed[self.decay_hint_locs.index((x + 1, y + 1))]:
                        new_x, new_y = (x, y + 1)
                elif y + 1 == 0:
                    # Moving back to the corridor from below
                    new_x, new_y = (x, y + 1)
        elif action == DOWN:
            if index == 1 and x == length - 1:  # agent 1 can only turn at the end of its corridor
                new_x, new_y = (x, y - 1)
            if index == 0 and self.num_decoy_buttons > 0:
                if (x + 1, y - 1) in self.decoy_button_locs:
                    # Moving onto a decoy button
                    new_x, new_y = (x, y - 1)
                elif self.decoy_buttons_same_hint and (x + 1, y - 1) in self.decay_hint_locs:
                    if self.decoy_buttons_pressed[self.decay_hint_locs.index((x + 1, y - 1))]:
                        new_x, new_y = (x, y - 1)
                elif y - 1 == 0:
                    # Moving back to the corridor from above
                    new_x, new_y = (x, y - 1)
        else:
            assert action == NOOP, (action, NOOP)

        new_x = np.clip(new_x, 0, length - 1)
        new_y = np.clip(new_y, -2, 2)

        # NOTE(review): only x is compared, so with decoy buttons this also fires
        # when agent 0 stands directly above/below the button. Confirm this is intended.
        if index == 0 and x == self.button_loc:
            self.button_pressed = True
        if index == 0 and self.num_decoy_buttons > 0 and (x + 1, y) in self.decoy_button_locs:
            self.decoy_buttons_pressed[self.decoy_button_locs.index((x + 1, y))] = True
        return new_x, new_y

    def step(self, action):
        """Advance both agents one step; ``action`` is ``[agent0_action, agent1_action]``.

        Returns ``(reward, done, info)``. The episode ends after ``episode_limit``
        steps or as soon as agent 1 leaves its corridor (y != 0). Agent 1 at
        y == 1 (UP) or y == -1 (DOWN) pays ``right_r`` if that matches the hint,
        else ``wrong_r``; otherwise the reward is 0.
        """
        self.agent0_loc = self.state_transition(0, self.agent0_loc, action[0])
        self.agent1_loc = self.state_transition(1, self.agent1_loc, action[1])
        _, a1_y = self.agent1_loc

        self.timestep += 1
        # Could also remove "a1_y != 0" to repeat the reward until the horizon is reached.
        done = (self.timestep >= self.episode_limit) or a1_y != 0

        reward = 0
        if a1_y == 1:
            reward = self.right_r if self.hint == UP else self.wrong_r
        if a1_y == -1:
            reward = self.right_r if self.hint == DOWN else self.wrong_r

        info = {'hint': 1 if self.hint == UP else 0, "button_pressed": 1 if self.button_pressed else 0}
        if self.num_decoy_buttons > 0:
            info["decoy_buttons_pressed"] = [1 if flag else 0 for flag in self.decoy_buttons_pressed]
        return reward, done, info

    def _hint_encoding(self):
        """Return the hint as a one-element list: [1] for UP, [-1] for DOWN."""
        return [1] if self.hint == UP else [-1]

    def _gen_decoy_buttons_locs(self):
        """Randomly place the decoy buttons and their hint cells, all released.

        Each decoy gets a distinct column from ``possible_decoy_pos`` and a row
        of +1 or -1. Its hint cell is the same column on the opposite row.
        """
        n_minus = random.randint(0, self.num_decoy_buttons)  # decoys placed at y = -1
        n_plus = self.num_decoy_buttons - n_minus
        db_y_list = [-1] * n_minus + [1] * n_plus
        random.shuffle(db_y_list)
        db_x_list = random.sample(self.possible_decoy_pos, self.num_decoy_buttons)
        self._db_x_list = db_x_list
        self._db_y_list = db_y_list
        self.decoy_button_locs = list(zip(db_x_list, db_y_list))
        self.decay_hint_locs = [(x, -y) for x, y in zip(db_x_list, db_y_list)]
        self.decoy_buttons_pressed = [False] * len(self.decoy_button_locs)

    def get_obs_for_agent(self, indx):
        """Return agent ``indx``'s observation as a list.

        Layout: ``[x, y, hint, button, *decoy_flags]``. ``hint`` is +/-1 when
        visible, else 0. Agent 0 sees ``button`` and the decoy flags as +/-1
        (pressed / not pressed). Agent 1 always sees 0 in those slots, so both
        observations have the same length.
        """
        loc = [self.agent0_loc, self.agent1_loc][indx]
        obs = list(loc)
        x = loc[0]
        hint_loc = self.hint_locs[indx]
        # Hint visible at the hint location (or anywhere if None), once the button is
        # pressed. Agent 0, who presses the button, does not need it pressed.
        agent_gets_hint = (hint_loc is None or x == hint_loc) and (self.button_pressed or indx == 0)
        if indx == 0 and not agent_gets_hint:
            # Agent 0 also sees the hint while standing on a decoy-hint cell
            # (always False when there are no decoy buttons).
            a0_x, a0_y = self.agent0_loc
            agent_gets_hint = (a0_x + 1, a0_y) in self.decay_hint_locs
        obs += self._hint_encoding() if agent_gets_hint else [0]
        if indx == 0:
            obs += [1] if self.button_pressed else [-1]  # "if x == self.button_loc" for local obs of button
            obs += [1 if flag else -1 for flag in self.decoy_buttons_pressed]
        else:
            obs += [0]  # to conform to the same len
            obs += [0] * len(self.decoy_buttons_pressed)
        return obs

    def get_obs(self):
        """Return both agents' observations stacked into a (2, obs_size) array."""
        return np.array([self.get_obs_for_agent(0), self.get_obs_for_agent(1)])

    def get_state(self):
        """Return the global state.

        Layout: ``[a0_x, a0_y, a1_x, a1_y, hint(+/-1), button(0/1), *decoy_flags(0/1)]``.
        """
        state = list(self.agent0_loc) + list(self.agent1_loc) + self._hint_encoding()
        state += [1] if self.button_pressed else [0]
        state += [1 if flag else 0 for flag in self.decoy_buttons_pressed]
        return np.array(state)

    def get_state_action_tables(self, init_value=0.0):
        """Return ``(agent_0_table, agent_1_table)`` for tabular learning.

        Each table maps every observation tuple ``(x, y, hint, button, *decoy_flags)``
        that :meth:`get_obs_for_agent` can produce for that agent to an array of
        ``len(ACTIONS)`` values, all initialised to ``init_value``.
        """
        possible_y = [-1, 0, 1]
        # The hint slot is 0 while the hint is hidden, else +/-1.
        possible_hints = [-1, 0, 1]
        agent_0_lists = (
            [range(self.lengths[0]), possible_y, possible_hints, [-1, 1]]
            + [[-1, 1]] * self.num_decoy_buttons
        )
        # Agent 1 never observes the button or the decoy buttons: those slots are always 0.
        agent_1_lists = (
            [range(self.lengths[1]), possible_y, possible_hints, [0]]
            + [[0]] * self.num_decoy_buttons
        )
        n_actions = len(ACTIONS)
        agent_0_table = {
            state: np.full(n_actions, init_value) for state in itertools.product(*agent_0_lists)
        }
        agent_1_table = {
            state: np.full(n_actions, init_value) for state in itertools.product(*agent_1_lists)
        }
        return agent_0_table, agent_1_table

    def get_obs_size(self):
        """Return the length of a single agent's observation."""
        return len(self.get_obs_for_agent(0))

    def get_state_size(self):
        """Return the length of the global state vector."""
        return len(self.get_state())

    def get_avail_actions(self):
        """Return the action ids available to each agent (every action, always).

        NOTE(review): this returns action *ids* [0..4], not a 0/1 availability
        mask. A learner that masks with ``avail_actions == 0`` would never pick
        LEFT (id 0). Confirm what callers expect.
        """
        return np.array([ACTIONS, ACTIONS])

    def get_total_actions(self):
        """Return the number of discrete actions per agent."""
        return len(ACTIONS)

    def get_obs_agent(self, agent_id):
        """Return agent ``agent_id``'s observation as a numpy array."""
        return np.array(self.get_obs_for_agent(agent_id))

    def get_avail_agent_actions(self, agent_id):
        """Return the action ids available to ``agent_id`` (always all of ACTIONS)."""
        return ACTIONS

    def close(self):
        pass

    def get_stats(self):
        pass

    def seed(self):
        raise NotImplementedError

    def render(self):
        """Print an ASCII picture of both corridors (agent 0 left, agent 1 right).

        Legend: ``a`` agent, ``b`` button / decoy button, ``u``/``d`` hint,
        ``h`` hint cell of a pressed decoy (when decoys do not reveal the real
        hint), ``.`` wall.
        """
        hint_char = "u" if self.hint == UP else "d"
        # agent 0 (on left:)
        a0line1 = ["."] * (self.lengths[0] + 2)
        a0line2 = ["."] + [" "] * self.lengths[0] + ["."]
        a0line3 = ["."] * (self.lengths[0] + 2)
        if self.hint_locs[0] is None:
            a0line1[self.starts[0] + 1] = hint_char
        else:
            a0line2[self.hint_locs[0] + 1] = hint_char
        a0line2[self.button_loc + 1] = "b"
        if self.num_decoy_buttons > 0:
            rows = {-1: a0line3, 0: a0line2, 1: a0line1}
            for (db_x, db_y), pressed in zip(self.decoy_button_locs, self.decoy_buttons_pressed):
                rows[db_y][db_x] = "b"
                if pressed:
                    rows[-db_y][db_x] = "h" if self.decoy_buttons_same_hint == 0 else hint_char
            rows[self.agent0_loc[1]][self.agent0_loc[0] + 1] = "a"
        else:
            a0line2[self.agent0_loc[0] + 1] = "a"
        # agent 1 (on right:)
        a1line1 = ["."] * self.lengths[1] + [" ", "."]
        a1line2 = ["."] + [" "] * self.lengths[1] + ["."]
        a1line3 = ["."] * self.lengths[1] + [" ", "."]
        if self.button_pressed:
            if self.hint_locs[1] is None:
                # Hint visible everywhere: draw it above the start, as for agent 0.
                a1line1[self.starts[1] + 1] = hint_char
            else:
                a1line2[self.hint_locs[1] + 1] = hint_char
        a1_row = [a1line3, a1line2, a1line1][self.agent1_loc[1] + 1]
        a1_row[self.agent1_loc[0] + 1] = "a"
        # print
        print("".join(a0line1) + "".join(a1line1))
        print("".join(a0line2) + "".join(a1line2))
        print("".join(a0line3) + "".join(a1line3))







if __name__ == "__main__":
    # Interactive demo: pick a configuration for `d`, then type one action per
    # agent per step. Alternative configurations to swap in for `d`:

    # Degenerate (TMaze): agent 0's corridor is a single cell holding the button.
    # d = {
    #     "lengths": (1, 15),
    #     "starts": (0, 7),
    #     "hint_locs": (None, 0),  # agent 0 always sees the hint
    #     "button_loc": 0,
    #     "episode_limit": 30,
    #     "right_r": 1,
    #     "wrong_r": -.5,
    #     "num_decoy_buttons": 0,
    #     "decoy_buttons_fixed": 1,
    #     "decoy_buttons_same_hint": 0,
    # }

    # Find Button
    # d = {
    #     "lengths": (7, 15),
    #     "starts": (3, 7),
    #     "hint_locs": (None, 0),  # agent 0 always sees the hint
    #     "button_loc": 6,
    #     "episode_limit": 30,
    #     "right_r": 1,
    #     "wrong_r": -.5,
    #     "num_decoy_buttons": 0,
    #     "decoy_buttons_fixed": 1,
    #     "decoy_buttons_same_hint": 0,
    # }

    # Find Button and Hint (hard): hint and button are at opposite ends of agent 0's corridor.
    # d = {
    #     "lengths": (7, 15),
    #     "starts": (3, 7),
    #     "hint_locs": (0, 0),  # agent 0 must walk to x=0 to read the hint
    #     "button_loc": 6,
    #     "episode_limit": 30,
    #     "right_r": 1,
    #     "wrong_r": -.5,
    #     "num_decoy_buttons": 0,
    #     "decoy_buttons_fixed": 1,
    #     "decoy_buttons_same_hint": 0,
    # }

    # Find Button and Hint (easy): hint and button are adjacent.
    # d = {
    #     "lengths": (7, 15),
    #     "starts": (3, 7),
    #     "hint_locs": (6, 0),  # agent 0 reads the hint at x=6
    #     "button_loc": 5,
    #     "episode_limit": 30,
    #     "right_r": 1,
    #     "wrong_r": -.5,
    #     "num_decoy_buttons": 0,
    #     "decoy_buttons_fixed": 1,
    #     "decoy_buttons_same_hint": 0,
    # }

    # Find Button and Hint with Decoy Buttons having the same hint
    # d = {
    #     "lengths": (7, 15),
    #     "starts": (3, 7),
    #     "hint_locs": (0, 0),  # agent 0 reads the hint at x=0 or on a pressed decoy's hint cell
    #     "button_loc": 6,
    #     "episode_limit": 30,
    #     "right_r": 1,
    #     "wrong_r": -.5,
    #     "num_decoy_buttons": 2,
    #     "decoy_buttons_fixed": 1,
    #     "decoy_buttons_same_hint": 1,
    # }

    # Active configuration: Find Button with Decoy Buttons.
    d = {
        "lengths": (7, 15),
        "starts": (3, 7),
        "hint_locs": (None, 0),  # agent 0 always sees the hint
        "button_loc": 6,
        "episode_limit": 30,
        "right_r": 1,
        "wrong_r": -.5,
        "num_decoy_buttons": 2,
        "decoy_buttons_fixed": 1,
        "decoy_buttons_same_hint": 0,
    }

    env = CMaze(env_args=d)
    env.get_state_action_tables()  # exercise table construction once

    key_to_action = {"a": LEFT, "d": RIGHT, "w": UP, "s": DOWN, "n": NOOP}

    def read_key(prompt):
        """Keep asking until the user types a key of `key_to_action`; return that key."""
        key = None
        while key not in key_to_action:
            print(prompt)
            print("actions allowed: w,a,s,d,n")
            key = input()
        return key

    num_ep = 2
    for ep in range(1, num_ep + 1):
        print("ep:", ep)
        obs, state = env.reset()
        print("\nobs:", obs)
        print("state:", state)
        env.render()
        done = False
        while not done:
            key0 = read_key("Action0?")
            key1 = read_key("Action1?")
            reward, done, info = env.step([key_to_action[key0], key_to_action[key1]])
            obs, state = env.get_obs(), env.get_state()
            print()
            env.render()
            for label, value in (("obs:", obs), ("state:", state), ("done:", done),
                                 ("info:", info), ("reward:", reward)):
                print(label, value)
        print("\nEND\n")
